FastChat Training Script Code Analysis - Train.py 【FastChat Series Part 1】

In this article, we delve into FastChat's train.py script (https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py), a key component for training and fine-tuning large language models (LLMs). FastChat (https://github.com/lm-sys/FastChat) is an advanced open-source platform focused on developing, deploying, and evaluating chatbots based on LLMs. The platform not only supports top-tier models like Vicuna and evaluation suites like MT-Bench but also includes a distributed multi-model serving system equipped with a Web UI and an OpenAI-compatible RESTful API, enabling efficient training and evaluation of models.

We provide a detailed analysis of the train.py source code. This training script, built on the transformers library, covers critical steps such as data preprocessing, model training, and checkpoint saving. Our goal is to explain each class and function in train.py, including its functionality and its role in the overall training process.

1. Importing Modules

1. Built-in Modules

These are standard library modules that come with Python and don’t require additional installation.

from dataclasses import dataclass, field

Imports dataclass and field from Python’s dataclasses module, used to define configuration classes with default values and metadata.

import json

Imports the json module for handling JSON format data.

import math

Imports the math module for mathematical operations.

import pathlib

Imports the pathlib module for handling file paths.

from typing import Dict, Optional, Sequence

Imports the typing module for type annotations.


2. Dependency Libraries

These are external libraries, typically installed via a package manager such as pip.

import numpy as np

Imports the numpy library, commonly used for scientific computing.

import torch

Imports PyTorch, a popular deep learning framework.

from torch.utils.data import Dataset

Imports Dataset from torch.utils.data, the base class used to create custom datasets.

import transformers

Imports the transformers library, a popular natural language processing library.

from transformers import Trainer

Imports Trainer from transformers for training models.

from transformers.trainer_pt_utils import LabelSmoother

Imports LabelSmoother from transformers. In this script it is used for its ignore_index constant, which marks the label positions that the loss function should skip.
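
In train.py this import is followed by the definition of the masking constant used later during preprocessing:

IGNORE_TOKEN_ID = LabelSmoother.ignore_index  # -100; label positions with this ID are skipped by the loss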

3. Project-Specific Functions

These are functions or classes implemented in the FastChat project itself.

from fastchat.conversation import SeparatorStyle

Imports SeparatorStyle from the fastchat package for defining conversation separator styles. The SeparatorStyle class is an enumeration class created using Python’s enum module, defining a series of separator styles. Enumerations are a programming concept used to define a named set of constants, making code clearer and more maintainable.

In the SeparatorStyle class, each member represents a specific style of separator. These styles are often used in text processing, especially in scenarios where different sections or elements need to be distinguished. For instance, in handling dialog or textual data, different methods might be needed to differentiate between user input and machine responses.

Regarding the use of the auto() function:

  • auto() is a special function provided by Python’s enum module. It automatically assigns a unique value to each member in an enumeration class.
  • Without using auto(), you would need to manually assign a unique value to each enumeration member. auto() simplifies this process by letting Python handle the assignment of these values automatically.
  • The values assigned by auto() are usually integers, starting from 1 and increasing sequentially.

In the case of the SeparatorStyle class, auto() is used to automatically assign a unique integer value to each type of separator style. For example, ADD_COLON_SINGLE, ADD_COLON_TWO, etc., will be given different integer values.

The names of each enumeration member (such as ADD_COLON_SINGLE, NO_COLON_SINGLE, etc.) typically describe the characteristics of that separator style. For instance, ADD_COLON_SINGLE might represent adding a colon as a separator after a certain element, whereas NO_COLON_SINGLE means no colon is added.

This approach makes referencing and handling these separator styles in the code more convenient and clear. For example, different separator styles can be chosen based on different scenarios or requirements without having to remember their specific values.
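
To make this concrete, here is a minimal runnable sketch of such an enumeration (the member set is abbreviated; the real SeparatorStyle in fastchat/conversation.py defines many more styles):

from enum import IntEnum, auto

class SeparatorStyle(IntEnum):
    """Abbreviated sketch of FastChat's separator-style enumeration."""
    ADD_COLON_SINGLE = auto()  # auto() assigns 1
    ADD_COLON_TWO = auto()     # auto() assigns 2
    NO_COLON_SINGLE = auto()   # auto() assigns 3

print(SeparatorStyle.ADD_COLON_TWO.name)   # ADD_COLON_TWO
print(SeparatorStyle.ADD_COLON_TWO.value)  # 2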

from fastchat.model.model_adapter import get_conversation_template

Imports get_conversation_template from the fastchat package for obtaining conversation templates. In this code segment, the call logic primarily involves obtaining the default conversation template for a specific model. The call chain is as follows:

  1. Starting Call - get_conversation_template(model_path: str)

    • This function is the starting point of the call chain. It accepts a parameter model_path, specifying the path of the model.
    • The purpose of this function is to obtain the default conversation template for the given model path.
  2. Call get_model_adapter(model_path: str)

    • The get_conversation_template function first calls get_model_adapter, passing in the model path.
    • The purpose of get_model_adapter is to find and return a suitable BaseModelAdapter object for the provided model path.
    • This function first tries to match the basename of model_path. If no match is found, it tries the full path.
    • If a suitable adapter is found, it is returned; otherwise, a ValueError is thrown.
  3. Execute BaseModelAdapter.get_default_conv_template(model_path: str)

    • Once the appropriate model adapter is obtained, get_conversation_template retrieves the default conversation template by calling the get_default_conv_template method of that adapter.
    • Note that this method is defined in the BaseModelAdapter class but might be overridden in subclasses.
  4. Call get_conv_template(name: str)

    • Inside the get_default_conv_template method, it calls the get_conv_template function, usually passing a predefined template name like "one_shot".
    • The purpose of get_conv_template is to retrieve a specified name’s template from the global registry of conversation templates conv_templates.
  5. Obtain and Return a Conversation Object

    • The get_conv_template function returns an instance of the Conversation class, usually copied from the conv_templates dictionary.
    • Finally, this Conversation instance is returned to the original call site of get_conversation_template.

Summarizing the call chain:

get_conversation_template(model_path)
-> get_model_adapter(model_path)
-> [BaseModelAdapter].get_default_conv_template(model_path)
-> get_conv_template(name)
-> Return Conversation Object


In this process, the code navigates through a series of function calls to find a suitable model adapter based on the provided model path and retrieve a specific conversation template from it. This design pattern allows flexibility in providing different conversation templates for different models, enhancing the reusability and extensibility of the code.
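
As a usage illustration (a sketch; the exact prompt text depends on the template registered for the model), the call chain can be exercised like this:

from fastchat.model.model_adapter import get_conversation_template

conv = get_conversation_template("vicuna")    # walks the adapter chain described above
conv.append_message(conv.roles[0], "Hello!")  # roles[0] is the user role, e.g. "USER"
conv.append_message(conv.roles[1], None)      # None leaves the assistant turn open
print(conv.get_prompt())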


2. Configuration Classes

These classes are defined using Python’s dataclass decorator and are mainly used for storing configurations and parameters. These classes usually do not contain complex methods or logic but are used to define and store data structures. These classes include:

  • ModelArguments: Stores parameters related to the model, such as the model path and whether to trust remote code.
  • DataArguments: Stores parameters related to data, like data path, evaluation data path, and whether to use lazy preprocessing.
  • TrainingArguments: Stores parameters related to training, like cache directory, optimizer type, model maximum length, etc. This class extends transformers.TrainingArguments and adds some custom parameters.

These classes are mainly used to simplify and organize parameter management in the code, making parameter modification and access more convenient.
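
In train(), these three classes are handed to transformers.HfArgumentParser, which maps command-line flags onto the typed dataclass fields:

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Invoked as e.g.: python train.py --model_name_or_path <model> --data_path data.json --output_dir out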

1. ModelArguments Class

Code

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for custom models defined on the Hub in their own modeling files"
        },
    )
    padding_side: str = field(
        default="right", metadata={"help": "The padding side in tokenizer"}
    )

Explanation

ModelArguments is a data class (dataclass) used for storing model-related configuration parameters.
Attributes:

  1. model_name_or_path: Specifies the name or path of the pretrained model.
  2. trust_remote_code: Whether to allow custom models that have their modeling files defined on the Hub.
  3. padding_side: Specifies the padding side in the tokenizer, typically right or left padding.
Introduction to the `@dataclass` decorator: `@dataclass` is a decorator used to automate the generation of special methods like `__init__()`, `__repr__()`, and `__eq__()`, thus simplifying the writing of data classes. It was introduced in Python 3.7 and lives in the `dataclasses` module.

When you use @dataclass before a class definition, Python automatically adds some special methods based on the fields defined in the class. This is very useful for creating classes that store a small amount of data but do not need complex methods.

Specifically, using @dataclass:

  1. Automatically generates a constructor (__init__ method): Python creates an __init__ method automatically based on the fields defined in the class, so you don’t need to manually write this method to initialize your class instances.

  2. Automatically generates a __repr__ method: This makes printing the class instances provide a more readable string representation, usually including the class name and its fields and their values.

  3. Automatically generates an __eq__ method: This allows you to use the == operator to compare two instances of the class, comparing the values of the instance fields.

  4. Support for type annotations: When defining fields, you can use type annotations, which not only help with clarity of code but can also be checked for type correctness using some tools.

In the case of the ModelArguments class, the @dataclass decorator will generate the above-mentioned methods. This means you can easily create an instance of ModelArguments, and when printing or comparing these instances, you will get the expected behavior.

For example, when you create an instance of ModelArguments:

args = ModelArguments()

This will call the automatically generated __init__ method, using the default values “facebook/opt-125m” for model_name_or_path, False for trust_remote_code, and “right” for padding_side.

When you print this instance:

print(args)

This will call the automatically generated __repr__ method, showing a detailed view of the class instance, like ModelArguments(model_name_or_path="facebook/opt-125m", trust_remote_code=False, padding_side="right").

Thus, the @dataclass decorator simplifies the process of creating classes, making the code more concise and maintainable.

Overall, the @dataclass decorator is a convenient tool provided by Python for quickly creating classes mainly used for storing data.
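
For comparison, a hand-written class providing roughly what @dataclass generates for ModelArguments might look like this (illustrative only):

class ModelArgumentsManual:
    def __init__(self, model_name_or_path="facebook/opt-125m",
                 trust_remote_code=False, padding_side="right"):
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.padding_side = padding_side

    def __repr__(self):
        return (f"ModelArgumentsManual(model_name_or_path={self.model_name_or_path!r}, "
                f"trust_remote_code={self.trust_remote_code!r}, "
                f"padding_side={self.padding_side!r})")

    def __eq__(self, other):
        return (self.model_name_or_path, self.trust_remote_code, self.padding_side) == (
            other.model_name_or_path, other.trust_remote_code, other.padding_side)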

2. DataArguments Class

Code

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = False

Explanation

  • DataArguments is also a data class used for storing data-related configuration parameters.
  • Attributes:
    • data_path: Path to the training data.
    • eval_data_path: Path to the evaluation data.
    • lazy_preprocess: Whether to use lazy loading for data preprocessing, i.e., load and process data as needed.

3. TrainingArguments Class

Code

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )

Explanation

TrainingArguments class extends transformers.TrainingArguments.

  1. TrainingArguments Class

    • TrainingArguments is a data class that, by extending transformers.TrainingArguments, gains the capability to handle training parameters.
    • Attributes defined in TrainingArguments:
      • cache_dir: Specifies the directory path for caching the model and tokenizer.
      • optim: Defines the type of optimizer to use, like 'adamw_torch'.
      • model_max_length: Specifies the maximum sequence length the model can handle.
  2. transformers.TrainingArguments Class

    • transformers.TrainingArguments is a class in the transformers library that is used for configuring various parameters in the model training process.
    • This class contains a plethora of attributes for controlling the training process, such as:
      • output_dir: Specifies the directory to save the model and training results.
      • num_train_epochs: Number of training epochs.
      • per_device_train_batch_size: Batch size per device for training.
      • save_steps: Steps interval for saving the model.
      • evaluation_strategy: Strategy for evaluating the model, like at the end of each epoch.
      • learning_rate: Learning rate.
      • warmup_steps: Steps used for warmup in the learning rate schedule.
    • transformers.TrainingArguments also contains many other parameters for fine-tuning the training process, including logging, model saving strategies, learning rate scheduling, and more.

By extending transformers.TrainingArguments, the TrainingArguments class not only inherits all these training parameter configurations but can also add some custom training parameters, like in this case cache_dir, optim, and model_max_length. This approach enhances code reusability and flexibility, allowing you to adjust and extend training configurations as per the specific requirements of your project.
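
A hypothetical instantiation mixing inherited and custom fields (the values are illustrative, not FastChat defaults):

training_args = TrainingArguments(
    output_dir="./checkpoints",      # inherited from transformers.TrainingArguments
    num_train_epochs=3,              # inherited
    per_device_train_batch_size=2,   # inherited
    learning_rate=2e-5,              # inherited
    model_max_length=2048,           # custom field added by this subclass
    optim="adamw_torch",             # custom default defined above
)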

3. Functional Utility Functions

1. rank0_print(*args)

Code

local_rank = None

def rank0_print(*args):
    if local_rank == 0:
        print(*args)

Explanation

Defines a global variable local_rank, which records this process's rank in distributed training; in train() it is populated from training_args.local_rank.
Defines a function rank0_print that prints only when local_rank is 0, used to control output in distributed training. This avoids every node repeating the same message, keeping the output clear and concise.

  • Used to print information only on the main node (rank 0) in a distributed training environment.
  • Parameters: A variable number of arguments for printing.
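
A self-contained sketch of the behavior, reusing the definition above (the rank values are simulated here; real runs get them from the distributed launcher):

local_rank = 0                      # set from training_args.local_rank in train()
rank0_print("loading dataset...")   # printed: this process is rank 0

local_rank = 2                      # a worker process
rank0_print("loading dataset...")   # suppressed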

2. trainer_save_model_safe(trainer: transformers.Trainer)

Code

def trainer_save_model_safe(trainer: transformers.Trainer):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()

The function trainer_save_model_safe(trainer: transformers.Trainer) aims to safely save models trained with the PyTorch distributed framework. Let’s delve into the details of this function and its key components.

Explanation

  1. Parameters:

    • trainer: An instance of transformers.Trainer. This class is one of the core components of the Hugging Face Transformers library, used for training and evaluating models.
  2. Functionality:

    • The main purpose of this function is to safely save models in a distributed training environment. It particularly considers the model saving strategy when using Fully Sharded Data Parallel (FSDP).
  3. FSDP

    • FullyShardedDataParallel (FSDP)
      • This is a component of PyTorch’s distributed training framework. FSDP helps reduce memory usage on each GPU by sharding model parameters across multiple GPUs, allowing the training of larger models.
      • In this context, FSDP is primarily used for handling and saving model states in distributed training.
    • StateDictType
      • This is an enumeration type that defines how to save the model’s state dictionary. In FSDP environments, saving and loading model states might require special handling.
    • FullStateDictConfig
      • This class configures parameters for saving the full state dictionary. It’s part of FSDP’s functionality and is used to control how the model state is saved.
  4. Function Implementation

    • Setting Save Policy
      • save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) creates a save policy. Here, two key parameters are specified:
        • offload_to_cpu: Offload model parameters to CPU before saving the state dictionary, which helps reduce GPU memory usage.
        • rank0_only: Save the model only on rank 0 (usually the main node). In distributed training, this avoids saving the same model copy on every node, saving storage space.
    • Saving the Model
      • Using the with FSDP.state_dict_type(trainer.model, StateDictType.FULL_STATE_DICT, save_policy) context manager, the type and policy for saving the model’s state dictionary are set.
      • Within this context, trainer.save_model() is called to save the model. Due to the save_policy, the model is saved securely following the specified configuration.

The function trainer_save_model_safe encapsulates a safe model saving logic, particularly for scenarios involving PyTorch’s FSDP in distributed training. It ensures that only a complete model state is saved on one node and offloads model parameters to CPU before saving, optimizing memory usage and storage efficiency. This is crucial for training large models and managing large-scale distributed training environments.
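
A sketch of a typical call site at the end of train() (this mirrors the pattern in FastChat's train.py, where DeepSpeed runs fall back to the plain trainer.save_model(); treat the details as illustrative):

trainer.train()
trainer.save_state()
if trainer.is_deepspeed_enabled:
    trainer.save_model()              # DeepSpeed manages its own sharded state
else:
    trainer_save_model_safe(trainer)  # FSDP: gather a full state dict on rank 0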

3. preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict

Code


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))
            exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" #turn = {len(turns) - 1}. (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

The function preprocess(sources, tokenizer: transformers.PreTrainedTokenizer) -> Dict turns raw dialogue data into tensors suitable for training. It can be broken down into several main parts:

1. Obtaining Conversation Templates and Role Definitions

conv = get_conversation_template("vicuna")
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
  • Functionality: Initializes conversation templates and defines the roles of dialogue participants.
  • Implementation:
    • conv = get_conversation_template("vicuna") obtains the conversation template for a specified model (e.g., “vicuna”).
    • The roles dictionary maps “human” and “gpt” to the roles defined in the conversation template.
  • Example:
    • If the conversation template is for “vicuna”, then roles might map “human” to “user” and “gpt” to “assistant”. For example, {'human': 'USER', 'gpt': 'ASSISTANT'}.

2. Applying Prompt Templates

1
2
3
4
5
6
7
8
9
10
11
12
13
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
    if roles[source[0]["from"]] != conv.roles[0]:
        # Skip the first one if it is not from human
        source = source[1:]

    conv.messages = []
    for j, sentence in enumerate(source):
        role = roles[sentence["from"]]
        assert role == conv.roles[j % 2], f"{i}"
        conv.append_message(role, sentence["value"])
    conversations.append(conv.get_prompt())
  • Functionality: Applies prompt templates to source data to construct dialogues.
  • Implementation:
    • Iterates through sources (original dialogue data), transforming each dialogue source into a conversation in template format.
    • If the first part of a dialogue is not initiated by the “human” role, it skips that part.
    • Assigns a role to each sentence and adds it to the conversation template.
    • Ultimately, each processed dialogue is added to the conversations list.
  • Example:
    • Suppose we have a source that is the first item in the dummy input: source = [{'from': 'human', 'value': 'Who are you?'}, {'from': 'gpt', 'value': 'I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).'}, {'from': 'human', 'value': 'Have a nice day!'}, {'from': 'gpt', 'value': 'You too!'}]
    • Under the Vicuna template, which uses SeparatorStyle.ADD_COLON_TWO as the separator style, conversations would then look roughly like [“A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user’s questions. USER: Who are you? ASSISTANT: I am Vicuna, a language model trained by researchers from Large Model Systems Organization (LMSYS).</s>USER: Have a nice day! ASSISTANT: You too!</s>”], where </s> is the second separator (conv.sep2) that closes each assistant turn.
    • Implementation of get_prompt: the output of the get_prompt method varies depending on the SeparatorStyle. The list below describes each style, with English examples:
      • ADD_COLON_SINGLE: Adds a colon and a separator after each message. Example: USER: Hello there!\nASSISTANT: Hi, how can I help?\n
      • ADD_COLON_TWO: Uses two alternating separators, usually between different roles. Example: USER: What’s the weather?\nASSISTANT: It’s sunny today.\n\n
      • ADD_COLON_SPACE_SINGLE: Adds a colon, a space, and a separator after each message. Example: USER: Can you book a flight?\nASSISTANT: Sure, where to?\n
      • NO_COLON_SINGLE: Messages directly follow the role name without a colon, followed by a separator. Example: USERWhat are you doing?\nASSISTANTI’m here to assist you.\n
      • NO_COLON_TWO: No colons, with two alternating separators. Example: USERHow’s the project going?\nASSISTANTIt’s on track.\n\n
      • ADD_NEW_LINE_SINGLE: Each message is preceded by a newline, followed by a separator. Example: USER\nHow can I reset my password?\nASSISTANT\nYou can reset it via email.\n
      • RWKV: Special format, usually for specific models. Example: USER: What is AI?\n\nASSISTANT: AI stands for Artificial Intelligence.\n\n
      • LLAMA2: Special label format for specific models. Example: [INST] USER How does blockchain work?\nASSISTANT It is a distributed ledger.\n\n
      • CHATGLM: Specific format for the CHATGLM model. Example: [Round 1]\nUSER: Tell me a joke.\nASSISTANT: Why did the chicken cross the road?\n
      • CHATML: Similar to CHATGLM, but with newlines before and after each message. Example: USER\nDo you like music?\n\nASSISTANT\nYes, I enjoy many genres.\n\n
      • CHATGLM3: Format for the CHATGLM3 model. Example: USER\nCan you play chess?\nASSISTANTYes, I can play.\n
      • CHATINTERN: Format for the CHATINTERN model, using special markers. Example: USER:Where is the nearest ATM?\nASSISTANT:It’s next to the post office.\n
      • DOLLY: Specific format for the DOLLY model. Example: USER:\nWhat is quantum computing?\nASSISTANT:\nIt involves computation using quantum-mechanical phenomena.\n\n
      • PHOENIX: For the PHOENIX model, messages are wrapped in special markers. Example: USER: How to bake a cake?\nASSISTANT: You need flour, sugar, and eggs.\n
      • ROBIN: Similar to ADD_NEW_LINE_SINGLE, but with a newline after the role name. Example: USER:\nIs AI dangerous?\nASSISTANT:\nIt depends on how it’s used.\n
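
To see one of these styles concretely, here is a small sketch using the Vicuna template (ADD_COLON_TWO); the system prompt and separators come from FastChat's template registry, so the exact output may differ across versions:

conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], "Who are you?")
conv.append_message(conv.roles[1], "I am Vicuna.")
print(conv.get_prompt())
# e.g. "A chat between a curious user ... USER: Who are you? ASSISTANT: I am Vicuna.</s>"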